#!/usr/bin/env python3
"""
1. Joint F-test: Are high-income Z×KAOPEN interactions jointly zero?
2. GE clearing decomposition: Why does rate absorption drop from 45-67% to 6-12%?
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path
from scipy import stats

# Root of the follow-up analysis; its parent directory is the main project tree.
FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/followup")
PROJECT_DIR = FOLLOWUP_DIR.parent
# Put both roots at the front of sys.path (follow-up first, then the parent
# project) so the `src` imports below are looked up there before anywhere else.
sys.path.insert(0, str(FOLLOWUP_DIR))
sys.path.insert(1, str(PROJECT_DIR))

# Project-local imports: these must stay below the sys.path setup above.
from src.model import PanelGLS
from src.macro import EBA_COUNTRIES, SSA_COUNTRIES, EU_EXPANSION, EXPANSION_TIER1, filter_eba_sample

# Result tables and processed inputs both live under the follow-up tree.
OUTPUT_DIR = FOLLOWUP_DIR.joinpath("output", "tables")
PROCESSED_DIR = FOLLOWUP_DIR.joinpath("data", "processed")

# The panel is read from the follow-up tree; the demographic polynomials are
# read from the parent project.
panel = pd.read_csv(PROCESSED_DIR / "full_panel.csv")
polys = pd.read_csv(PROJECT_DIR / "data" / "processed" / "demographic_polynomials.csv")

# Estimation window: rows with an observed ca_gdp and 1986 <= year <= 2024.
has_ca = panel['ca_gdp'].notna()
in_window = panel['year'].between(1986, 2024)
est = panel.loc[has_ca & in_window].copy()

full_sample = filter_eba_sample(est, extended=True, expansion=True)

# ═══════════════════════════════════════════════════════════════════════════
# 1. JOINT F-TESTS FOR HIGH-INCOME Z×KAOPEN
# ═══════════════════════════════════════════════════════════════════════════
print("=" * 70)
print("JOINT F-TESTS: HIGH-INCOME Z×KAOPEN INTERACTIONS")
print("=" * 70)


def _income_tier(level, low_cut, high_cut):
    """Return 'high', 'middle' or 'low' for one country's median GDP per capita.

    A missing level is assigned to the middle tier. Otherwise the level is
    'high' at or above ``high_cut``, 'middle' at or above ``low_cut``, and
    'low' below that.
    """
    if pd.isna(level):
        return 'middle'
    if level >= high_cut:
        return 'high'
    return 'middle' if level >= low_cut else 'low'


# Income classification (same as phase4cd): per-country median GDP per capita,
# falling back to nominal GDP / population when the PPP series is not present.
if 'gdp_pc_ppp' not in full_sample.columns:
    full_sample['gdp_pc_approx'] = full_sample['ngdp_usd'] / full_sample['population_weo']
    gdp_pc = full_sample.groupby('iso3')['gdp_pc_approx'].median()
else:
    gdp_pc = full_sample.groupby('iso3')['gdp_pc_ppp'].median()

# Tier cut-offs at the 33rd and 67th percentiles of the country medians.
q33 = gdp_pc.quantile(0.33)
q67 = gdp_pc.quantile(0.67)
income_map = {iso: _income_tier(level, q33, q67) for iso, level in gdp_pc.items()}

full_sample['income_group'] = full_sample['iso3'].map(income_map)

# Build interaction variables.
demo_vars = ['Z_1', 'Z_2', 'Z_3']
income_tiers = ['high', 'middle', 'low']

# One 1.0/0.0 indicator column per income tier.
for tier in income_tiers:
    full_sample[f'is_{tier}'] = full_sample['income_group'].eq(tier).astype(float)

# Three-way terms: Z × KAOPEN × tier indicator (Z varies slowest in the list).
three_way_vars = [f'{z}_x_kaopen_x_{g}' for z in demo_vars for g in income_tiers]
for z_name in demo_vars:
    z_times_kaopen = full_sample[z_name] * full_sample['kaopen']
    for tier in income_tiers:
        full_sample[f'{z_name}_x_kaopen_x_{tier}'] = z_times_kaopen * full_sample[f'is_{tier}']

control_vars = [
    'fiscal_bal_gdp', 'kaopen', 'expected_growth',
    'nfa_gdp_lag', 'log_rel_opw', 'health_exp_gdp',
]
rate_var = ['log_lending_rate']

# Complete-case sample: the dependent variable plus every regressor of the
# two-way (Z × KAOPEN) specification must be observed.
two_way_vars = [f'{z}_x_kaopen' for z in demo_vars]
ext_base = demo_vars + control_vars + rate_var + two_way_vars
comp = full_sample.dropna(subset=['ca_gdp'] + ext_base).copy()

# --- Helpers shared by Tests A-E ---
def _fit_panel_gls(frame, regressors):
    """Fit a PanelGLS of ``ca_gdp`` on ``regressors`` using the rows of ``frame``.

    The ``iso3`` and ``year`` columns are handed to ``fit`` as its third and
    fourth positional arguments. Returns the fitted model object.
    """
    model = PanelGLS()
    model.fit(frame['ca_gdp'].values, frame[regressors].values,
              frame['iso3'].values, frame['year'].values)
    return model


def _nested_f_test(model_u, model_r, n_restrictions, n_regressors):
    """Compare a restricted fit with the unrestricted fit that nests it.

    Uses the R²-form statistic
    ``F = ((R²_u - R²_r) / q) / ((1 - R²_u) / (n - k - 1))`` with
    ``n = model_u.n_obs``, ``k = n_regressors`` and ``q = n_restrictions``.

    Returns ``(F, p_value, residual_df)``.
    """
    df_resid = model_u.n_obs - n_regressors - 1
    gain = (model_u.r_squared - model_r.r_squared) / n_restrictions
    noise = (1 - model_u.r_squared) / df_resid
    f_stat = gain / noise
    # Use the survival function for the upper tail: `1 - cdf` collapses to
    # exactly 0.0 once the cdf is within machine epsilon of 1.
    p_value = stats.f.sf(f_stat, n_restrictions, df_resid)
    return f_stat, p_value, df_resid


def _stars(p):
    """Return the significance marker for p-value ``p`` (0.001 / 0.01 / 0.05)."""
    if p < 0.001:
        return '***'
    if p < 0.01:
        return '**'
    if p < 0.05:
        return '*'
    return ''


def _three_way_without(excluded_tier):
    """List the Z×KAOPEN×tier regressors for every tier except ``excluded_tier``."""
    return [f'{z}_x_kaopen_x_{g}'
            for z in ['Z_1', 'Z_2', 'Z_3']
            for g in ['high', 'middle', 'low'] if g != excluded_tier]


def _subsample_kaopen_test(frame, base_regs, ext_regs):
    """Test the regressors added in ``ext_regs`` on one subsample and print the results.

    Fits the baseline and extended models on ``frame``, prints both R² values,
    the joint F-test of the added terms, and each coefficient whose name
    contains ``x_kaopen``. Returns ``(base_model, ext_model, F, p_value)``.
    """
    base_fit = _fit_panel_gls(frame, base_regs)
    ext_fit = _fit_panel_gls(frame, ext_regs)
    n_added = len(ext_regs) - len(base_regs)
    f_stat, p_value, df_resid = _nested_f_test(ext_fit, base_fit, n_added, len(ext_regs))
    print(f"  Baseline R² = {base_fit.r_squared:.4f}")
    print(f"  + Z×KAOPEN R² = {ext_fit.r_squared:.4f}")
    print(f"  F({n_added},{df_resid}) = {f_stat:.3f}, p = {p_value:.4f}")
    # NOTE(review): indexes beta/pvalues by position in the regressor list, i.e.
    # assumes PanelGLS stores them in that order with no leading intercept -
    # confirm against src.model.
    for i, v in enumerate(ext_regs):
        if 'x_kaopen' in v:
            p_coef = ext_fit.pvalues[i]
            print(f"  {v}: {ext_fit.beta[i]:.3f} (p={p_coef:.4f}) {_stars(p_coef)}")
    return base_fit, ext_fit, f_stat, p_value


base_vars = demo_vars + control_vars + rate_var
vars_unrestricted = base_vars + three_way_vars

# --- Test A: Are high-income interactions jointly zero in the three-way model? ---
# Restricted model: three-way but with high-income interactions constrained to zero
# = model with only middle and low interactions
print("\n--- Test A: Joint significance of HIGH-income Z×KAOPEN ---")
vars_restricted_no_high = base_vars + _three_way_without('high')

gls_r = _fit_panel_gls(comp, vars_restricted_no_high)
gls_u = _fit_panel_gls(comp, vars_unrestricted)

q = 3  # three Z×KAOPEN terms are dropped per income tier
k_u = len(vars_unrestricted)
n = gls_u.n_obs
F_high, p_high, df_u = _nested_f_test(gls_u, gls_r, q, k_u)
print(f"  Restricted (no high interactions): R² = {gls_r.r_squared:.4f}")
print(f"  Unrestricted (all interactions):   R² = {gls_u.r_squared:.4f}")
print(f"  F({q},{df_u}) = {F_high:.3f}, p = {p_high:.4f}")
if p_high > 0.05:
    print(f"  >>> CANNOT REJECT that high-income interactions are jointly zero (p={p_high:.3f})")
else:
    print(f"  >>> High-income interactions ARE jointly significant (p={p_high:.3f})")

# --- Test B: Are middle-income interactions jointly zero? ---
print("\n--- Test B: Joint significance of MIDDLE-income Z×KAOPEN ---")
vars_restricted_no_mid = base_vars + _three_way_without('middle')
gls_r2 = _fit_panel_gls(comp, vars_restricted_no_mid)

F_mid, p_mid, _ = _nested_f_test(gls_u, gls_r2, q, k_u)
print(f"  Restricted (no middle interactions): R² = {gls_r2.r_squared:.4f}")
print(f"  F({q},{df_u}) = {F_mid:.3f}, p = {p_mid:.4f}")

# --- Test C: Are low-income interactions jointly zero? ---
print("\n--- Test C: Joint significance of LOW-income Z×KAOPEN ---")
vars_restricted_no_low = base_vars + _three_way_without('low')
gls_r3 = _fit_panel_gls(comp, vars_restricted_no_low)

F_low, p_low, _ = _nested_f_test(gls_u, gls_r3, q, k_u)
print(f"  Restricted (no low interactions): R² = {gls_r3.r_squared:.4f}")
print(f"  F({q},{df_u}) = {F_low:.3f}, p = {p_low:.4f}")

# Regressor sets for the subsample tests: baseline vs. baseline + Z×KAOPEN.
vars_high_sub = base_vars
vars_high_sub_ext = vars_high_sub + ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']

# --- Test D: High-income only subsample - are Z×KAOPEN significant at all? ---
print("\n--- Test D: Z×KAOPEN on HIGH-income subsample only ---")
high_only = comp[comp['income_group'] == 'high'].copy()
print(f"  High-income subsample: {high_only['iso3'].nunique()} countries, {len(high_only)} obs")
gls_h_base, gls_h_ext, F_h, p_h = _subsample_kaopen_test(
    high_only, vars_high_sub, vars_high_sub_ext)

# --- Test E: Low+middle subsample ---
print("\n--- Test E: Z×KAOPEN on LOW+MIDDLE income subsample only ---")
lm_only = comp[comp['income_group'] != 'high'].copy()
print(f"  Low+middle subsample: {lm_only['iso3'].nunique()} countries, {len(lm_only)} obs")
gls_lm_base, gls_lm_ext, F_lm, p_lm = _subsample_kaopen_test(
    lm_only, vars_high_sub, vars_high_sub_ext)


# ═══════════════════════════════════════════════════════════════════════════
# 2. GE CLEARING DECOMPOSITION
# ═══════════════════════════════════════════════════════════════════════════
print("\n\n" + "=" * 70)
print("GE CLEARING DECOMPOSITION: WHY 6-12% vs 45-67%?")
print("=" * 70)


def _ratio_label(new, old):
    """Format ``new / old`` as e.g. ``'1.23x'``; return ``'n/a'`` when ``old`` is zero.

    A zero denominator occurs when a Z term is absent from the original
    coefficient table, because the lookup below defaults it to 0.
    """
    if old == 0:
        return "n/a"
    return f"{new / old:.2f}x"


# Load 140-country baseline coefficients (all three Z terms must be present;
# a missing one raises KeyError here).
coeffs = pd.read_csv(OUTPUT_DIR / "regression_baseline_demo_plus_eba_140.csv")
coeff_map = dict(zip(coeffs['variable'], coeffs['coefficient']))
z_names = ['Z_1', 'Z_2', 'Z_3']
z_betas = {z: coeff_map[z] for z in z_names}

# Load original 69-country coefficients (optional input: the comparison is
# skipped when the file does not exist).
orig_coeffs_path = PROJECT_DIR / "output" / "tables" / "regression_baseline_demo_plus_eba.csv"
if orig_coeffs_path.exists():
    orig_coeffs = pd.read_csv(orig_coeffs_path)
    orig_coeff_map = dict(zip(orig_coeffs['variable'], orig_coeffs['coefficient']))
    orig_z_betas = {z: orig_coeff_map.get(z, 0) for z in z_names}
    print(f"\n  Original 69c Z betas: {orig_z_betas}")
    print(f"  Expanded 140c Z betas: {z_betas}")
    print(f"  Coefficient ratio: Z1={_ratio_label(z_betas['Z_1'], orig_z_betas['Z_1'])}, "
          f"Z2={_ratio_label(z_betas['Z_2'], orig_z_betas['Z_2'])}, "
          f"Z3={_ratio_label(z_betas['Z_3'], orig_z_betas['Z_3'])}")

# GDP weights: each country's share of summed ngdp_usd in the latest year that
# has any non-missing ngdp_usd.
gdp_data = panel[['iso3', 'year', 'ngdp_usd']].dropna()
latest_year = gdp_data['year'].max()
gdp_weights = gdp_data[gdp_data['year'] == latest_year][['iso3', 'ngdp_usd']].copy()
total_gdp = gdp_weights['ngdp_usd'].sum()
gdp_weights['weight'] = gdp_weights['ngdp_usd'] / total_gdp
weight_map = dict(zip(gdp_weights['iso3'], gdp_weights['weight']))

# Classify countries
all_countries_69 = set(EBA_COUNTRIES + SSA_COUNTRIES)
expansion_countries = set(EU_EXPANSION + EXPANSION_TIER1)

# Compute PE contributions by country group for key years
proj_years = [2025, 2030, 2035, 2040, 2050, 2060]
print(f"\n--- PE Imbalance Decomposition by Country Group ---")

# Group membership does not depend on the year, so build it once.
groups = {
    'EBA-49': set(EBA_COUNTRIES),
    'SSA (orig 20)': set(SSA_COUNTRIES),
    'EU expansion': set(EU_EXPANSION),
    'Other expansion': expansion_countries - set(EU_EXPANSION),
}
# Only these years are reported; the other projection years are skipped
# because nothing computed for them in this section is used.
report_years = {2025, 2040, 2050}
# Magnitude floor for the denominator of the "share of imbalance" column.
share_floor = 0.001

for year in proj_years:
    if year not in report_years:
        continue

    yr_polys = polys[polys['year'] == year]
    # One row of Z values per country (the first one if a country repeats),
    # indexed by iso3 so each lookup avoids re-filtering the frame.
    z_by_country = yr_polys.drop_duplicates('iso3').set_index('iso3')[z_names]

    total_pe = 0
    total_weight = 0
    group_contributions = {}

    for grp_name, grp_countries in groups.items():
        grp_pe = 0
        grp_weight = 0
        for iso3 in grp_countries:
            w = weight_map.get(iso3, 0)
            # Skip countries without a GDP weight or without Z values this year.
            if w == 0 or iso3 not in z_by_country.index:
                continue
            z_row = z_by_country.loc[iso3]
            demo_ca = sum(z_betas[zv] * z_row[zv] for zv in z_names)
            grp_pe += w * demo_ca
            grp_weight += w

        group_contributions[grp_name] = (grp_pe, grp_weight)
        total_pe += grp_pe
        total_weight += grp_weight

    if total_weight == 0:
        # No country had both a weight and Z values: nothing to normalise by.
        print(f"\n  Year {year}: no countries with both GDP weight and demographic data")
        continue

    # Keep the sign of the total when flooring its magnitude, so group shares
    # are not sign-flipped for a small negative imbalance.
    if abs(total_pe) > share_floor:
        share_denom = total_pe
    else:
        share_denom = np.copysign(share_floor, total_pe)

    print(f"\n  Year {year} (total PE imbalance = {total_pe/total_weight:.3f} pp):")
    for grp_name, (grp_pe, grp_weight) in group_contributions.items():
        pct = grp_pe / share_denom * 100
        print(f"    {grp_name:20s}: PE contrib = {grp_pe/total_weight:+.3f} pp  "
              f"(GDP weight = {grp_weight:.3f}, share of imbalance: {pct:+.1f}%)")

# --- Compare: what would happen with original 69-country sample + 140c coefficients? ---
def _demo_ca_140c(year_rows, iso3):
    """Return the 140c-coefficient demographic CA for ``iso3``, or None if it has no rows.

    Uses the first matching row of ``year_rows`` and sums beta × Z over the Z terms.
    """
    country_rows = year_rows[year_rows['iso3'] == iso3]
    if len(country_rows) == 0:
        return None
    return sum(z_betas[zv] * country_rows[zv].values[0] for zv in z_names)


print(f"\n\n--- Counterfactual: Original 69 countries with 140c coefficients ---")
poly_countries = polys['iso3'].unique()
for year in proj_years:
    yr_polys = polys[polys['year'] == year]

    # GDP-weighted sums over the original 69 countries and over every country
    # that has both a weight and polynomial data for this year.
    pe_69 = w_69 = 0
    pe_140 = w_140 = 0

    for iso3 in poly_countries:
        w = weight_map.get(iso3, 0)
        demo_ca = _demo_ca_140c(yr_polys, iso3) if w != 0 else None
        if demo_ca is None:
            continue

        pe_140 += w * demo_ca
        w_140 += w
        if iso3 in all_countries_69:
            pe_69 += w * demo_ca
            w_69 += w

    # A sample with zero total weight is reported as a zero imbalance.
    imb_69 = pe_69 / w_69 if w_69 > 0 else 0
    imb_140 = pe_140 / w_140 if w_140 > 0 else 0
    print(f"  {year}: 69-country PE = {imb_69:+.3f} pp | 140-country PE = {imb_140:+.3f} pp | "
          f"diff = {imb_140 - imb_69:+.3f} pp")

# --- What if we use ORIGINAL coefficients on 140 countries? ---
print(f"\n--- Counterfactual: 140 countries with ORIGINAL 69c coefficients ---")
if orig_coeffs_path.exists():
    poly_countries = polys['iso3'].unique()
    for year in proj_years:
        yr_polys = polys[polys['year'] == year]
        # One row of Z values per country (the first one if a country repeats),
        # indexed by iso3 so each lookup avoids re-filtering the frame.
        z_by_country = yr_polys.drop_duplicates('iso3').set_index('iso3')[z_names]

        pe_orig_coef = 0
        pe_new_coef = 0
        w_tot = 0

        for iso3 in poly_countries:
            w = weight_map.get(iso3, 0)
            # Skip countries without a GDP weight or without Z values this year.
            if w == 0 or iso3 not in z_by_country.index:
                continue
            z_row = z_by_country.loc[iso3]

            demo_ca_orig = sum(orig_z_betas[zv] * z_row[zv] for zv in z_names)
            demo_ca_new = sum(z_betas[zv] * z_row[zv] for zv in z_names)

            pe_orig_coef += w * demo_ca_orig
            pe_new_coef += w * demo_ca_new
            w_tot += w

        # Report a zero imbalance when no country contributed, matching the
        # zero-weight guard in the 69-vs-140 counterfactual, instead of
        # raising ZeroDivisionError.
        imb_orig = pe_orig_coef / w_tot if w_tot > 0 else 0
        imb_new = pe_new_coef / w_tot if w_tot > 0 else 0
        print(f"  {year}: orig coefs = {imb_orig:+.3f} pp | new coefs = {imb_new:+.3f} pp | "
              f"coef effect = {imb_new - imb_orig:+.3f} pp")
else:
    # Say why the section is empty rather than printing a bare header.
    print(f"  (skipped: {orig_coeffs_path} not found)")

# --- Top 10 contributors to PE surplus and deficit at 2050 ---
print(f"\n--- Top Contributors to PE Imbalance at 2050 ---")
yr_polys = polys[polys['year'] == 2050]

# One record per country that has both a GDP weight and 2050 polynomial data.
country_pe = []
for iso3 in polys['iso3'].unique():
    w = weight_map.get(iso3, 0)
    rows_2050 = yr_polys[yr_polys['iso3'] == iso3]
    if w == 0 or len(rows_2050) == 0:
        continue
    demo_ca = sum(z_betas[zv] * rows_2050[zv].values[0] for zv in z_names)
    country_pe.append(dict(iso3=iso3, weight=w, demo_ca=demo_ca,
                           weighted_ca=w * demo_ca))

cpe = pd.DataFrame(country_pe)
total_w = cpe['weight'].sum()
# Normalise by the covered weight so the contributions add up to the global PE imbalance.
cpe['weighted_ca_norm'] = cpe['weighted_ca'] / total_w

# Largest positive contributions first, then the most negative ones.
rankings = (
    ("SURPLUS", cpe.nlargest(10, 'weighted_ca_norm')),
    ("DEFICIT", cpe.nsmallest(10, 'weighted_ca_norm')),
)
for side, ranked in rankings:
    print(f"\n  Top 10 {side} contributors (weighted):")
    for row in ranked.itertuples(index=False):
        print(f"    {row.iso3:>4}: demo_ca = {row.demo_ca:+.2f} pp, "
              f"GDP weight = {row.weight:.4f}, "
              f"contribution to global PE = {row.weighted_ca_norm:+.3f} pp")

contributions = cpe['weighted_ca_norm']
global_surplus = contributions[contributions > 0].sum()
global_deficit = contributions[contributions < 0].sum()
print(f"\n  Total surplus side: {global_surplus:+.3f} pp")
print(f"  Total deficit side: {global_deficit:+.3f} pp")
print(f"  Net PE imbalance:   {global_surplus + global_deficit:+.3f} pp")
print(f"  Gross imbalance:    {global_surplus - global_deficit:+.3f} pp")

# --- Rate channel capacity analysis ---
delta_rate = 0.127   # rate semi-elasticity δ used throughout this section
max_delta_r = 2.0    # cap on |Δr*| in pp; single source for the "±2 pp" bound


def rate_channel_clearing(imbalance, delta=delta_rate, cap=max_delta_r):
    """Return how much of a PE imbalance the rate channel clears under a capped Δr*.

    Parameters:
        imbalance: PE imbalance in pp of CA; must be non-zero.
        delta: rate semi-elasticity (CA response per pp of Δr*).
        cap: largest allowed |Δr*| in pp.

    Returns:
        ``(uncapped, capped, cleared_pct)``: the Δr* that would clear the whole
        imbalance, that value clipped to ``[-cap, +cap]``, and the percentage of
        the imbalance absorbed at the clipped value.

    Raises:
        ZeroDivisionError: if ``imbalance`` or ``delta`` is zero.
    """
    uncapped = imbalance / delta
    # Clip symmetrically so a negative imbalance is bounded at -cap as well.
    capped = max(-cap, min(uncapped, cap))
    cleared_pct = (delta * capped) / imbalance * 100
    return uncapped, capped, cleared_pct


print(f"\n\n--- Rate Channel Capacity Analysis ---")
print(f"  Rate semi-elasticity δ = {delta_rate}")
print(f"  Max Δr* = ±{max_delta_r:g} pp")
print(f"  Max rate channel absorption per country = δ × {max_delta_r:g} = "
      f"{delta_rate * max_delta_r:.3f} pp of CA")
print(f"  For the rate channel to clear X pp of global PE imbalance:")
print(f"    Need Δr* = X / δ = X / {delta_rate}")
for imb in [0.5, 1.0, 2.0, 3.0, 4.0, 5.0]:
    dr, capped, cleared = rate_channel_clearing(imb)
    print(f"    PE imbalance = {imb:.1f} pp → uncapped Δr* = {dr:.1f} pp → "
          f"capped Δr* = {capped:.1f} pp → clears {cleared:.0f}%")

# --- Why is PE imbalance so much larger? ---
print(f"\n\n--- Diagnosis: Why is PE imbalance larger with 140 countries? ---")
print(f"\n  Factor 1: COEFFICIENT SIZE")
# orig_z_betas is only defined when the 69-country coefficient table was found,
# and its Z_1 entry defaults to 0 when that table lacks the variable; in either
# case the ratio cannot be formed.
orig_z1 = orig_z_betas.get('Z_1', 0) if orig_coeffs_path.exists() else 0
if orig_z1:
    print(f"    140c Z₁ is {z_betas['Z_1']/orig_z1:.1f}x the 69c Z₁")
    print(f"    This mechanically amplifies ALL PE projections")
else:
    print(f"    69c Z₁ coefficient unavailable; size comparison skipped")

print(f"\n  Factor 2: ASYMMETRIC SAMPLE COMPOSITION")
# Split the 2050 country table by the sign of the unweighted demographic CA.
in_surplus = cpe['demo_ca'] > 0
in_deficit = cpe['demo_ca'] < 0
n_surplus = int(in_surplus.sum())
n_deficit = int(in_deficit.sum())
w_surplus = cpe.loc[in_surplus, 'weight'].sum()
w_deficit = cpe.loc[in_deficit, 'weight'].sum()
print(f"    Surplus countries: {n_surplus} (GDP weight: {w_surplus:.3f})")
print(f"    Deficit countries: {n_deficit} (GDP weight: {w_deficit:.3f})")
print(f"    Many deficit countries are small (low GDP weight) so don't offset")

# Also compute: what share of AE GDP is in surplus?
# Advanced economies = EBA countries minus this hand-listed set of codes.
ae_countries = set(EBA_COUNTRIES) - {
    'CHN', 'IND', 'IDN', 'BRA', 'RUS', 'ZAF',
    'MEX', 'TUR', 'COL', 'ARG', 'PER', 'CHL',
    'EGY', 'PAK', 'MAR', 'PHL', 'THA', 'MYS',
    'SAU', 'ARE',
}
ae_pe = cpe[cpe['iso3'].isin(ae_countries)]
print(f"\n    Advanced economies at 2050:")
print(f"      N = {len(ae_pe)}, GDP weight = {ae_pe['weight'].sum():.3f}")
print(f"      All in surplus: {(ae_pe['demo_ca'] > 0).all()}")
print(f"      Mean demo_ca = {ae_pe['demo_ca'].mean():+.2f} pp")
print(f"      Weighted contribution = {ae_pe['weighted_ca_norm'].sum():+.3f} pp")

# Save summary: one row per joint test, as (label, F statistic, p-value).
joint_tests = [
    ('High-income joint F', F_high, p_high),
    ('Middle-income joint F', F_mid, p_mid),
    ('Low-income joint F', F_low, p_low),
    ('High subsample joint F', F_h, p_h),
    ('Low+Mid subsample joint F', F_lm, p_lm),
]
summary = {
    'test': [label for label, _, _ in joint_tests],
    'F_stat': [f_stat for _, f_stat, _ in joint_tests],
    'p_value': [p_val for _, _, p_val in joint_tests],
}
summary_path = OUTPUT_DIR / "kaopen_income_joint_tests.csv"
pd.DataFrame(summary).to_csv(summary_path, index=False)

print(f"\n\nSaved joint test results to {summary_path}")
print("\nDone.")
